# ------------------------------------------------------------------------------
# Performs IV regression using leave-out test rate as instrument
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 02_iv_regression_tbls.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(glue) # glue strings
library(lfe)
library(OneR) # bin
library(stargazer)

# Directory for this step's intermediate files: the cohort is read from here
# and the regression tables produced below are written here.
temp <- here::here("code", "05_natural_experiment", "temp")
# NOTE(review): `overnight_lab` is not referenced anywhere else in this script
# as shown -- confirm whether it is still needed.
overnight_lab <- ""
# Project utilities loaded as a module; supplies `u$safe_left_join()`, used
# when merging long-term outcomes onto the cohort.
u <- modules::use(here::here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- yaml::read_yaml(here::here("lib", "filepaths.yml"))
longterm_outcomes <- readRDS(paths$analysis$longterm_outcomes)

# Build the analysis cohort:
#   * derive the predicted-risk variable and decile/quintile indicators,
#   * drop rows whose `split` is "val",
#   * attach long-term outcomes,
#   * add instrument-by-risk interaction terms and the day 31-365 outcome that
#     excludes encounters already positive on `macetrop_030_pos`.
cohort <- file.path(temp, "shift_test_rates.rds") %>%
  readRDS() %>%
  mutate(
    yhat = p__ensemble__stent_or_cabg_010_day__tested,
    risk_dec_10 = tile_stent_or_cabg_010_tested == 10,
    risk_decile = factor(tile_stent_or_cabg_010_tested),
    risk_quintile = factor(tile_stent_or_cabg_005_tested),
    risk_quint_5 = tile_stent_or_cabg_005_tested == 5
    # uncomment to produce tables in terms of SD of t-rate
    # shift_12_trate_leaveout = shift_12_trate_leaveout/sd(shift_12_trate_leaveout)
  ) %>%
  filter(split != "val") %>%
  u$safe_left_join(longterm_outcomes) %>%
  mutate(
    shift_times_yhat = shift_12_trate_leaveout * yhat,
    shift_times_risk_dec_10 = shift_12_trate_leaveout * risk_dec_10,
    shift_times_risk_quint_5 = shift_12_trate_leaveout * risk_quint_5,
    macetrop_31_to_365_pos_excl = macetrop_31_to_365_pos & !macetrop_030_pos
  )

# Adverse outcomes used both for the base-rate report here and as dependent
# variables in the outcome regressions further down.
adverse_outcomes <- c(
  "macetrop_31_to_365_pos_excl",
  "death_31_to_365",
  "death_365_day"
)

# Row positions of encounters in the top risk quintile. `which()` drops NA
# comparisons, matching how `filter()` treats them.
top_quint_rows <- which(cohort$tile_stent_or_cabg_005_tested == 5)

# Report each outcome's rate overall and within the top risk quintile. This
# runs on the cohort as loaded, before the per-patient restriction below.
for (outcome in adverse_outcomes) {
  outcome_values <- cohort[[outcome]]
  mean_overall <- mean(outcome_values)
  mean_top_quint <- mean(outcome_values[top_quint_rows])

  message("-----------")
  message("Outcome: ", outcome)
  message("Overall rate: ", mean_overall)
  message("Rate in top quintile: ", mean_top_quint)
  message("-----------")
}

# Keep a single encounter per patient: sort by `t0` and take the first
# `ed_enc_id` observed for each `ptid`.
keep_encs <- cohort %>%
  slice(order(t0)) %>%
  group_by(ptid) %>%
  slice(1) %>%
  ungroup() %>%
  pull(ed_enc_id)

# Restrict the cohort to those encounters.
cohort <- filter(cohort, ed_enc_id %in% keep_encs)

# Descriptive statistics on the restricted cohort: instrument spread, mean
# predicted risk and mean test rate, overall and in the top risk decile/quintile.
top_dec <- cohort %>% filter(risk_dec_10)
top_quint <- cohort %>% filter(tile_stent_or_cabg_005_tested == 5)

message("SD of shift test rate: ", sd(cohort$shift_12_trate_leaveout))
message("Mean y-hat: ", mean(cohort[["yhat"]]))
message("Mean y-hat in top decile: ", mean(top_dec[["yhat"]]))
message("Mean y-hat in top quintile: ", mean(top_quint[["yhat"]]))
message("Mean test rate: ", mean(cohort[["test_010_day"]]))
message(
  "Mean test rate in top decile of risk: ",
  mean(top_dec[["test_010_day"]])
)
message(
  "Mean test rate in top quintile of risk: ",
  mean(top_quint[["test_010_day"]])
)

# IV Regression ----------------------------------------------------------------
# For each risk variable in `risk_vars` this section:
#   1. fits the first stage (test_010_day on the leave-out shift test rate),
#      without and with a risk interaction, and stores the fitted values on
#      `cohort` as `step1_preds` / `step1_preds_inter`;
#   2. fits an felm IV specification purely to print its conditional F stat;
#   3. regresses every adverse outcome on the test variable, interacted and
#      non-interacted with the risk variable;
#   4. writes the resulting stargazer tables as .tex files under `temp`.
# `reg_list_stage1` accumulates the first-stage fits across risk variables and
# is reused by the yield table later in the script.
message("Fitting IV regression...")

risk_vars <- c("yhat") # , "risk_dec_10")
# risk_vars <- c("risk_dec_10")
omit_stats <- c("f", "ser", "ll", "aic", "adj.rsq")
reg_list_stage1 <- list()
for (risk_var in risk_vars) {
  reg_list <- list()
  reg_list_nointer <- list()
  message("Risk variable: ", risk_var)

  # Step 1: w/o interaction
  message("Step 1: No Interaction")
  form_step1 <- reformulate(
    response = "test_010_day",
    termlabels = c(
      risk_var, "shift_12_trate_leaveout",
      "1 | hour + wday + week + year | 0 | ptid + shift_12"
    )
  )
  fit_step1 <- felm(form_step1, data = cohort)
  cohort[["step1_preds"]] <- as.numeric(fit_step1$fitted.values)

  print(summary(fit_step1))
  # this is just to get conditional F stat from felm
  # (clustered on shift_12 only; adding ptid is currently disabled)
  iv_form <- formula(glue(
    "death_365_day ~ {risk_var} | hour + wday + week + year | ",
    "(test_010_day ~ shift_12_trate_leaveout) | shift_12"
  ))
  iv_fit <- felm(iv_form, data = cohort)

  message("Conditional F Stat: No Interaction")
  print(condfstat(iv_fit))

  # Step 1: with interaction
  # (clustered on shift_12 only; adding ptid is currently disabled)
  message("Step 1: Interaction")
  form_step1_inter <- reformulate(
    response = "test_010_day",
    termlabels = c(
      glue("{risk_var}*shift_12_trate_leaveout"),
      "1 | hour + wday + week + year | 0 | shift_12"
    )
  )
  fit_step1_inter <- felm(form_step1_inter, data = cohort)
  cohort[["step1_preds_inter"]] <- as.numeric(fit_step1_inter$fitted.values)
  print(summary(fit_step1_inter))

  message("SD step 1 preds (no interaction): ", sd(cohort$step1_preds))
  message("SD step 1 preds (with interaction): ", sd(cohort$step1_preds_inter))

  reg_list_stage1 <- c(reg_list_stage1, list(fit_step1), list(fit_step1_inter))

  # Only the leave-out shift test rate is used as the test variable for now;
  # the loop is kept so "step1_preds" can be added back to the vector.
  for (test_var in c("shift_12_trate_leaveout")) {
    message("Test var: ", test_var)
    # reg_list <- list()
    # reg_list_nointer <- list()

    # When the first-stage predictions are the test variable, the interacted
    # model uses the predictions from the interacted first stage instead.
    test_var_inter <- if (test_var == "step1_preds") {
      "step1_preds_inter"
    } else {
      test_var
    }

    # Work on a copy so the generic `test_var*` columns never land on `cohort`.
    reg_df <- copy(cohort)
    reg_df[["test_var"]] <- reg_df[[test_var]]
    reg_df[["test_var_inter"]] <- reg_df[[test_var_inter]]

    # cluster_vars_inter <- ifelse(test_var == "step1_preds", "shift_12", "0")
    cluster_vars_inter <- "0"
    message("Step 2: Adverse Outcomes")
    for (outcome in adverse_outcomes) {
      # Interacted specification: outcome ~ risk * test variable + fixed effects
      form <- reformulate(
        response = outcome,
        termlabels = c(
          glue("{risk_var}*test_var_inter"),
          glue("1 | hour + wday + week + year | 0 | {cluster_vars_inter}")
        )
      )
      fit <- felm(form, data = reg_df)
      reg_list <- c(reg_list, list(fit))
      message("Interacted formula:")
      print(form)

      # Non-interacted specification, clustered on shift_12
      form_nointer <- reformulate(
        response = outcome,
        termlabels = c(
          risk_var, "test_var",
          "1| hour + wday + week + year | 0 | shift_12"
        )
      )
      fit_nointer <- felm(form_nointer, data = reg_df)
      reg_list_nointer <- c(reg_list_nointer, list(fit_nointer))
      message("Non-interacted formula:")
      print(form_nointer)
    }
  }

  # stargazer() returns the LaTeX lines (and also prints them); write them out.
  message("Saving adverse outcome reg results for ", risk_var, "...")
  reg_tbl <- stargazer(
    reg_list, omit.stat = omit_stats, dep.var.caption = "", digits = 2
  )
  save_fp <- file.path(
    temp, glue("longterm_regs__adverse_interacted__{risk_var}.tex")
  )
  write(reg_tbl, save_fp, append = FALSE)

  reg_tbl_nointer <- stargazer(
    reg_list_nointer, omit.stat = omit_stats, dep.var.caption = "",
    digits = 2
  )
  save_fp <- file.path(temp, glue("longterm_regs__adverse__{risk_var}.tex"))
  write(reg_tbl_nointer, save_fp, append = FALSE)
}

# Yield Balance
# Among encounters with `test_010_day` set, regress stent_or_cabg_010_day on
# the risk variable and the leave-out shift test rate, first additively and
# then interacted. The saved table lists these fits followed by the first-stage
# fits collected in `reg_list_stage1`.
message("Step 2: Yield in Tested")
tested <- filter(cohort, test_010_day)
yield_reg_list <- list()
for (with_interaction in c(FALSE, TRUE)) {
  # "*" expands to main effects plus the interaction; "+" is additive only.
  term_operator <- if (with_interaction) "*" else "+"
  for (risk_var in risk_vars) {
    risk_test_term <- paste(risk_var, term_operator, "shift_12_trate_leaveout")
    form <- reformulate(
      response = "stent_or_cabg_010_day",
      termlabels = c(
        risk_test_term,
        "1 | hour + wday + week + year | 0 | ptid + shift_12"
      )
    )
    fit <- felm(form, data = tested)
    yield_reg_list <- c(yield_reg_list, list(fit))
  }
}

# stargazer() returns the LaTeX lines (and also prints them); write them out.
yield_test_tbl <- stargazer(
  c(yield_reg_list, reg_list_stage1), omit.stat = omit_stats,
  dep.var.caption = ""
)
save_fp <- file.path(temp, "yield_test__on__shift_trate__regs.tex")
write(yield_test_tbl, save_fp)

message("Done.")
